import numpy as np
import torch
import os
from ASVBaselineTool.feature_tools import *
import soundfile as sf
import torchaudio
import torch.nn as nn
import torchattacks

# DeepFool attack framework (previous comment incorrectly said FGSM)
class Deepfool():
    """DeepFool adversarial attack for a two-class (genuine/fake) audio
    anti-spoofing model.

    Reads a label file of ``<fname> ... <label>`` lines, runs DeepFool on
    each audio file under ``datasetPath``, and writes every successfully
    perturbed waveform to ``AdvsetPath``.
    """

    # model, label file, clean dataset path, adversarial-sample output path
    def __init__(self,model,labelPath,datasetPath,AdvsetPath,overshoot,steps=80):
        """
        Args:
            model: classifier under attack (moved to CUDA and set to eval()).
            labelPath: text file; the first space-separated token of each line
                is the file name, the last is the label ("genuine" or "fake").
            datasetPath: directory containing the original audio files.
            AdvsetPath: directory where adversarial samples are saved.
            overshoot: DeepFool overshoot factor applied to the final noise.
            steps: maximum DeepFool iterations per sample.
        """
        # Parse file names and labels from the label file.
        with open(labelPath, "r") as label_file:
            lines = label_file.readlines()
        self.AFnames = [line.split(" ")[0] for line in lines]
        raw_labels = [line.split(" ")[-1].strip() for line in lines]
        # File name -> one-hot label tensor: genuine=[0,1], fake=[1,0].
        # (Unknown label strings are stored unchanged, matching prior behavior.)
        self.LabelMeta = {}
        for fname, label in zip(self.AFnames, raw_labels):
            if label == "genuine":
                label = torch.Tensor([0.0, 1.0])
            elif label == "fake":
                label = torch.Tensor([1.0, 0.0])
            self.LabelMeta[fname] = label
        self.model = model
        self.model.cuda()
        self.model.eval()
        self.overshoot = overshoot
        self.datasetPath = datasetPath
        # Output directory for adversarial samples (same index as AFnames).
        self.AdvSavePath = AdvsetPath
        self.steps = steps

    def feature_select(self,feature,wav:torch.Tensor):
        """Map a waveform to the front-end feature named by ``feature``.

        Known names: "spec_2048", "mfcc_40", "lfcc_70", and "raw" (the raw
        waveform reshaped to (1, n)).  Returns None for an unknown name.
        """
        if feature == "spec_2048":
            return get_SPEC_2048(wav)
        elif feature == "mfcc_40":
            return get_MFCC_40(wav)
        elif feature == "lfcc_70":
            return get_LFCC_70(wav)
        elif feature == "raw":
            return wav.reshape(1, -1)

    def model_bn_to_eval(self):
        """Switch the model to train() mode but freeze every BatchNorm layer.

        Keeps BN running statistics fixed while gradients flow through the
        rest of the network during the attack.
        NOTE(review): assumes the RawNet-style layout below — confirm against
        the actual model definition.
        """
        self.model.train()
        self.model.first_bn.eval()
        self.model.block0[0].bn2.eval()  # block0 has no bn1
        for block in (self.model.block1, self.model.block2, self.model.block3,
                      self.model.block4, self.model.block5):
            block[0].bn1.eval()
            block[0].bn2.eval()
        self.model.bn_before_gru.eval()

    def Attack(self,feature:str):
        """Attack every listed sample and report the overall success rate.

        Args:
            feature: front-end feature name passed to feature_select.

        Skips samples the model already misclassifies, saves only successful
        adversarial waveforms, and prints per-sample progress plus a summary.
        """
        num_f = len(self.AFnames)
        print("总样本：" + str(num_f))
        # Count of successful attacks (transfer rate numerator).
        num_as = 0
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for item in self.AFnames:
            src_wav, sr = sf.read(os.path.join(self.datasetPath, item))
            src_wav = torch.Tensor(src_wav)
            if feature == "raw":
                # Freeze BN layers so batch statistics don't drift per sample.
                self.model_bn_to_eval()
            # Clean-sample prediction.
            adv_feature = self.feature_select(feature, src_wav).cuda()
            batch_x = self.model(adv_feature)
            init_pred = torch.max(batch_x, dim=1)[1].item()
            # Ground-truth class index from the one-hot label.
            y = self.LabelMeta[item]
            init_y = y.max(0, keepdim=True)[1].item()
            # Already misclassified: nothing to attack.
            if init_pred != init_y:
                print("原始分类错误 pass")
                continue
            # Correctly classified: run the attack.
            perturbed_wav = self.deepfool(src_wav, y, feature, None)
            perturbed_wav = perturbed_wav.detach().cpu()
            temp_adv_feature = self.feature_select(feature, wav=perturbed_wav)
            temp_adv_feature = temp_adv_feature.to(device)
            temp_batch_x = self.model(temp_adv_feature)
            temp_init_pred = temp_batch_x.max(1)[1].item()
            # Attack failed: prediction unchanged.
            if temp_init_pred == init_y:
                print("攻击失败 pass")
                continue
            sp = os.path.join(self.AdvSavePath, item)
            print("攻击成功并保存值路径:" + sp)
            num_as += 1
            sf.write(sp, perturbed_wav.numpy(), sr)
        print("总样本：" + str(num_f))
        print("攻击成功样本数量:" + str(num_as))
        # Guard against an empty label file (division by zero).
        print("攻击成功率:" + str(num_as / num_f if num_f else 0.0))

    def deepfool(self, x, y,feature:str, y_target=None):
        """DeepFool attack on a single waveform.

        Args:
            x: source waveform tensor.
            y: one-hot ground-truth label tensor.
            feature: front-end feature name (see feature_select).
            y_target: unused; kept for interface compatibility.

        Returns:
            The adversarial waveform, clamped to [-1, 1].
        """
        # If the clean sample is already misclassified there is nothing to do.
        with torch.no_grad():
            logits = self.model(self.feature_select(feature, x).cuda())
            if torch.argmax(logits, dim=1) != y.max(0, keepdim=True)[1].item():
                return x
        self.nb_classes = logits.size(-1)

        adv_x = x.clone().detach().requires_grad_()
        logits = self.model(self.feature_select(feature, adv_x).cuda())
        original = logits.max(1)[1].item()
        current = original
        noise = torch.zeros(x.size()).cuda()
        w = torch.zeros(x.size()).cuda()

        iteration = 0
        while current == original and iteration < self.steps:
            grad_current = torch.autograd.grad(logits[0, current], [adv_x],
                                               retain_graph=True)[0].detach()
            # BUGFIX: `pert` was reset to inf inside the loop, so for >2
            # classes the *last* class (not the nearest boundary) was always
            # chosen.  Initialize once, per the DeepFool algorithm.
            pert = np.inf
            for k in range(self.nb_classes):
                if k == current:
                    continue
                grad_k = torch.autograd.grad(logits[0, k], [adv_x],
                                             retain_graph=True)[0].detach()
                w_k = (grad_k - grad_current).cuda()
                f_k = logits[0, k] - logits[0, current]
                pert_k = (torch.abs(f_k)) / torch.norm(w_k.unsqueeze(0).flatten(1), 2, -1)
                if pert_k < pert:
                    pert = pert_k
                    w = w_k
            # Minimal step (plus a small constant) toward the chosen boundary.
            r_i = (pert + 1e-4) * w / torch.norm(w.unsqueeze(0).flatten(1), 2, -1)
            # NOTE(review): `noise` is cumulative and re-added to the already
            # perturbed adv_x each step (kept from the original; canonical
            # DeepFool applies noise to the clean x) — confirm intent.
            noise += r_i.clone()
            adv_x = torch.clamp(adv_x.cuda() + noise, -1, 1).detach().cpu().requires_grad_()
            logits = self.model(self.feature_select(feature, adv_x).cuda())
            current = logits.max(1)[1].item()
            iteration += 1

        # Overshoot pushes the final sample slightly past the boundary.
        adv_x = torch.clamp((1 + self.overshoot) * noise + x.cuda(), -1, 1)

        return adv_x
















